# ------------------------------------------------------------------------------
# Scores the test cohort with the two saved split ensemble models and
# assigns risk tiles
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 07_predict_test.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
set.seed(1)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(Matrix)
library(tidyverse)
library(glue)
library(optparse)
library(here) # here() relative filepaths
library(testit) # assert()
library(OneR) # bin()

# Directory the inputs are read from and the scored cohort is written to.
# NOTE(review): this points at the 08_train_split_model step although the
# header's run command names 07_predict_test.sh -- confirm this is intended.
temp <- here::here("code", "08_train_split_model", "temp")

# Project utility module; its functions are reached as u$<name>.
u <- modules::use(here::here("lib", "util.R"))

# Helpers ----------------------------------------------------------------------
#' Assign every element of `x` to a quantile tile, with the cutpoints
#' estimated on a subset of `x`.
#'
#' The interior cutpoints are the k / n_tiles quantiles
#' (k = 1, ..., n_tiles - 1) of `x[which_x]`. The outermost breaks are the
#' minimum and maximum of all of `x`, so every non-NA element gets a tile even
#' when it lies outside the range of the subset.
#'
#' @param x Numeric vector to tile.
#' @param n_tiles Number of tiles; a single number >= 1.
#' @param which_x Indices into `x` whose values define the cutpoints.
#' @param showWarnings If `TRUE` (default), warn when `x[which_x]` contains
#'   NA values. Set to `FALSE` to silence that warning.
#' @return Integer vector the same length as `x` with values in
#'   `1..n_tiles`; NA wherever `x` is NA.
#' @details `cut()` raises an error when the resulting breaks are not unique
#'   (for example with heavy ties in `x`); that error is not handled here.
ntile_within <- function(x, n_tiles, which_x, showWarnings = TRUE) {
  if (!isTRUE(all(which_x <= length(x)))) {
    stop("Indices of which_x in x", call. = FALSE)
  }
  if (length(n_tiles) != 1L || is.na(n_tiles) || n_tiles < 1) {
    stop("n_tiles must be a single number >= 1", call. = FALSE)
  }

  x_ref <- x[which_x]
  if (showWarnings && anyNA(x_ref)) {
    warning("X contains NA values; establishing tiles on non-NA values.")
  }

  # seq_len() yields no interior cutpoints for n_tiles == 1, whereas
  # 1:(n_tiles - 1) would count down to c(1, 0) and produce invalid breaks.
  probs <- seq_len(n_tiles - 1L) / n_tiles
  inner <- quantile(x_ref, probs, na.rm = TRUE)

  # Use the min and max of the full vector as the outer breaks.
  cutpoints <- c(min(x, na.rm = TRUE), inner, max(x, na.rm = TRUE))

  x_tiled <- as.integer(cut(
    x,
    breaks = cutpoints,
    labels = seq_len(n_tiles),
    include.lowest = TRUE
  ))

  # Every non-NA input must have landed in some tile.
  if (sum(is.na(x)) != sum(is.na(x_tiled))) {
    stop("Num NAs same in x, x_tiled", call. = FALSE)
  }
  x_tiled
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# NOTE(review): `overnight_lab` is never referenced by name below; presumably
# the `test_cohort` path template interpolates it through glue() -- confirm
# against lib/filepaths.yml before touching it.
overnight_lab <- ""

# Encounter id and exclusion flag read from the test cohort file.
test_exclude <- select(
  readRDS(glue(paths$analysis$test_cohort)),
  ed_enc_id, exclude
)

# Subscores file with the exclusion flag joined on via the project helper.
visits <- u$safe_left_join(
  readRDS(file.path(temp, "subscores_test_set.rds")),
  test_exclude
)

# The two saved ensemble model objects.
ensemble_1 <- readRDS(
  file.path(temp, "ensemble__stent_or_cabg_010_day__tested__1.rds")
)
ensemble_2 <- readRDS(
  file.path(temp, "ensemble__stent_or_cabg_010_day__tested__2.rds")
)

# Predict ----------------------------------------------------------------------
message("Predicting ensemble in test set...")
# Model 2 is scored after model 1's column has been added, as before.
visits[["p__ensemble__1"]] <- predict(ensemble_1, visits)
visits[["p__ensemble__2"]] <- predict(ensemble_2, visits)

# Risk Tiles -------------------------------------------------------------------
message("Getting risk tiles in tested...")

# The row masks depend only on `exclude` / `test_010_day`, which the loop never
# modifies, so they are built once up front.
keep_x <- !visits$exclude
keep_x_tested <- keep_x & visits$test_010_day
rows_all <- which(keep_x)
rows_tested <- which(keep_x_tested)

for (i in 1:2) {
  risk <- paste0("p__ensemble__", i)
  for (tile in c(5, 10, 100)) {
    message("Risk variable: ", risk)
    message("Tiles: ", tile)

    # Zero-padded tile count for the column name, e.g. 5 -> "005".
    tile_lab <- sprintf("%03d", tile)
    tile_var <- paste0("tile_", tile_lab, "_ensemble__", i)
    tile_var_tested <- paste0(tile_var, "__tested")

    # Cutpoints estimated on all non-excluded rows ...
    visits[[tile_var]] <- ntile_within(visits[[risk]], tile, which_x = rows_all)
    # ... and on non-excluded rows that also have test_010_day set.
    visits[[tile_var_tested]] <- ntile_within(
      visits[[risk]], tile,
      which_x = rows_tested
    )
  }
}

# Save -------------------------------------------------------------------------
message("Saving scored test cohort...")
save_fp <- file.path(temp, "scored_test_cohort.rds")
readr::write_rds(visits, save_fp)

message("Done.")
